package difftest

import com.sun.jna.{Memory, Native, Pointer}
import difftest.DifftestCoreNative.RefLib
import utils.logger

import java.io.{File, FileInputStream}

/** Scala-side driver for a difftest reference model ("ref") loaded through JNA.
  *
  * Holds the currently loaded reference library (if any), a native buffer used to exchange
  * register state with it, and helpers to initialise/release the library, copy registers and
  * memory, step it, raise interrupts and load a program image into its memory.
  */
object DifftestCore {
  /** Maximum number of image bytes copied by [[loadImage]]; longer files are truncated. */
  val IMAGE_MAX_SIZE = 0x80000
  // val IMAGE_DEFAULT = "/win32-x86-64/bin/hello-riscv32-nemu.bin"
  /** Default image: looked up as a classpath resource first, then relative to the working directory. */
  val IMAGE_DEFAULT = "/src/test/resources/bin/hello-riscv32-nemu.bin"
  /** Default load address. 0x80000000 does not fit in a positive Int, so this holds the same
    * 32-bit pattern as Int.MinValue.
    */
  val MEM_BASE = 0x80000000
  /** The reference library; `None` until [[init]] is called and again after [[release]]. */
  var ref: Option[RefLib] = None

  /** Returns the loaded reference library, failing with a descriptive error if [[init]] was not called. */
  private def refLib: RefLib =
    ref.getOrElse(
      throw new NoSuchElementException("DifftestCore: reference library not loaded; call init(port) first")
    )

  /** Register state mirrored in a native buffer (`dst`) that is handed to `difftest_regcpy`.
    *
    * Layout used by [[load]]/[[dump]]: `gprSize` 32-bit GPRs at byte offset 0, followed by a
    * 32-bit `pc`.
    */
  class CPUState {
    val gprSize = 32
    // One zero-initialised slot per general-purpose register.
    var gpr: Array[Int] = new Array[Int](gprSize)
    var pc: Int = 0
    // Size in bytes of a JVM `int` as reported by JNA; used to compute byte offsets into `dst`.
    val sizeUnit: Int = Native.getNativeSize(classOf[Int])
    // Only (gprSize + 1) * sizeUnit bytes are read/written here. The extra 4x space is kept from
    // the original allocation, presumably as headroom in case the ref copies a larger CPU state
    // into this buffer — TODO confirm against the ref's CPU_state definition.
    val dst = new Memory(sizeUnit * (gprSize + 1) * 4)

    /** Copies GPRs and pc from `dst` into `gpr`/`pc`. */
    def load(): Unit = {
      gpr = dst.getIntArray(0, gprSize)
      // Offsets are in bytes: pc sits right after the GPR array.
      pc = dst.getInt(gprSize * sizeUnit)
    }

    /** Copies `gpr`/`pc` into `dst`.
      *
      * @throws IllegalArgumentException if `gpr` holds fewer than `gprSize` registers
      */
    def dump(): Unit = {
      require(gpr.length >= gprSize, s"CPUState.gpr must hold at least $gprSize registers, got ${gpr.length}")
      // Bulk copy of gprSize ints to byte offset 0; JNA handles the per-element stride.
      dst.write(0, gpr, 0, gprSize)
      dst.setInt(gprSize * sizeUnit, pc)
    }
  }

  val cpu = new CPUState

  /** Copies registers between [[cpu]] and the ref (CSRs are not copied).
    *
    * @param direction `RefLib.TO_REF` pushes `cpu` into the ref; any other value asks the ref to
    *                  fill `cpu.dst` and then loads it into `cpu`.
    */
  def regcpy(direction: Boolean): Unit = {
    def execRegcpy(): Unit = refLib.difftest_regcpy(cpu.dst, direction, copy_csr = false)

    if (direction == RefLib.TO_REF) {
      cpu.dump()
      execRegcpy()
    } else {
      execRegcpy()
      cpu.load()
    }
  }

  /** Forwards to the ref's `difftest_memcpy` for `n` bytes at `addr` using `buf`. */
  def memcpy(addr: Int, buf: Pointer, n: Int, direction: Boolean): Unit =
    refLib.difftest_memcpy(addr, buf, n, direction)

  /** Forwards to the ref's `difftest_exec` with `n`. */
  def exec(n: Int): Unit = refLib.difftest_exec(n)

  /** Forwards to the ref's `difftest_raise_intr` with interrupt number `no`. */
  def raise_intr(no: Int): Unit = refLib.difftest_raise_intr(no)

  /** Loads the reference library, stores it in [[ref]] and calls its `difftest_init(port)`. */
  def init(port: Int): Unit = {
    val lib = RefLib.init()
    ref = Some(lib)
    lib.difftest_init(port)
  }

  /** Calls the ref's `difftest_release`, releases the library and clears [[ref]]. */
  def release(): Unit = {
    refLib.difftest_release()
    // NOTE(review): presumably lets JNA-backed objects be collected before the native library is
    // released — confirm.
    System.gc()
    RefLib.release()
    ref = None
  }

  /** Resolves [[IMAGE_DEFAULT]] to a filesystem path.
    *
    * Uses the classpath resource when it is a plain file (decoding URL escapes such as `%20`),
    * otherwise falls back to the same path relative to the working directory.
    */
  private def defaultImagePath: String =
    Option(getClass.getResource(IMAGE_DEFAULT))
      .filter(_.getProtocol == "file")
      .map(url => new File(url.toURI).getPath)
      .getOrElse(new File(IMAGE_DEFAULT.stripPrefix("/")).getAbsolutePath)

  /** Reads from `in` until `buf` is full or end of stream; returns the number of bytes read. */
  private def readFully(in: FileInputStream, buf: Array[Byte]): Int = {
    var filled = 0
    var eof = false
    while (!eof && filled < buf.length) {
      val n = in.read(buf, filled, buf.length - filled)
      if (n < 0) eof = true else filled += n
    }
    filled
  }

  /** Copies a binary image into the ref's memory.
    *
    * At most [[IMAGE_MAX_SIZE]] bytes are loaded; a longer file is truncated (and this is logged).
    *
    * @param memStartAddr  destination address in the ref's memory
    * @param imageFileName path of the image file
    * @throws java.io.FileNotFoundException if the file cannot be opened
    * @throws IllegalArgumentException      if the file is empty
    */
  def loadImage(memStartAddr: Int = MEM_BASE, imageFileName: String = defaultImagePath): Unit = {
    logger.info(s"loadImage($memStartAddr, $imageFileName)")
    val imageFile = new File(imageFileName)
    if (imageFile.length() > IMAGE_MAX_SIZE) {
      logger.info(
        s"loadImage: $imageFileName is ${imageFile.length()} bytes; only the first $IMAGE_MAX_SIZE bytes are loaded"
      )
    }
    val imageData = new Array[Byte](IMAGE_MAX_SIZE)
    val imageDataReader = new FileInputStream(imageFile)
    val imageSize =
      try readFully(imageDataReader, imageData)
      finally imageDataReader.close()
    // JNA cannot allocate a zero-sized Memory, and an empty image is almost certainly a mistake.
    require(imageSize > 0, s"loadImage: image file $imageFileName is empty")
    val image = new Memory(imageSize.toLong)
    image.write(0, imageData, 0, imageSize)
    memcpy(memStartAddr, image, imageSize, RefLib.TO_REF)
  }
}
